# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Type

import gradio as gr

from swift.ui.base import BaseUI


class Sample(BaseUI):
    """UI section holding the sampling options of the ``llm_sample`` group.

    The class only declares widgets: two rows of gradio components whose
    ``elem_id`` values match the keys of :attr:`locale_dict`. No ``label`` is
    passed to the components here; presumably ``BaseUI`` resolves the
    localized label from ``locale_dict`` via ``elem_id`` -- confirm against
    ``swift.ui.base``.
    """

    # Identifier of the UI group this section belongs to.
    group = "llm_sample"

    # Localized ("zh" / "en") labels, keyed by the ``elem_id`` of each widget
    # created in ``do_build_ui``.
    locale_dict = {
        "sampler_type": {
            "label": {"zh": "采样类型", "en": "Sampler type"},
        },
        "sampler_engine": {
            "label": {"zh": "推理引擎", "en": "Infer engine"},
        },
        "num_return_sequences": {
            "label": {
                "zh": "采样返回的原始序列数量",
                "en": "Num of original sequences returned by sampling",
            },
        },
        "n_best_to_keep": {
            "label": {"zh": "最佳序列数量", "en": "Num of best sequences"},
        },
        "max_new_tokens": {
            "label": {"zh": "生成序列最大长度", "en": "Max new tokens"},
        },
        "temperature": {
            "label": {"zh": "采样温度", "en": "Temperature"},
        },
        "top_k": {
            "label": {"zh": "Top-k", "en": "Top-k"},
        },
        "top_p": {
            "label": {"zh": "Top-p", "en": "Top-p"},
        },
        "repetition_penalty": {
            "label": {"zh": "重复惩罚", "en": "Repetition Penalty"},
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type["BaseUI"]):
        """Create the sampling widgets inside the currently open gradio context.

        Args:
            base_tab: The UI class this section is built for. It is not used
                by this method.

        Returns:
            None. The components are created for their side effect of being
            attached to the enclosing gradio layout.
        """
        # Row 1: sampler selection and how many sequences to draw / keep.
        with gr.Row():
            gr.Dropdown(
                elem_id="sampler_type",
                choices=["sample", "mcts", "distill"],
                value="sample",
                scale=5,
            )
            gr.Dropdown(
                elem_id="sampler_engine",
                choices=["pt", "lmdeploy", "vllm", "no", "client"],
                value="pt",
                scale=5,
            )
            gr.Slider(
                elem_id="num_return_sequences",
                minimum=1,
                maximum=128,
                step=1,
                value=64,
                scale=5,
            )
            gr.Slider(
                elem_id="n_best_to_keep",
                minimum=1,
                maximum=64,
                step=1,
                value=5,
                scale=5,
            )
        # Row 2: generation hyper-parameters.
        with gr.Row():
            # Kept as a free-text box with a string default, not a numeric widget.
            gr.Textbox(elem_id="max_new_tokens", lines=1, value="2048")
            gr.Slider(
                elem_id="temperature", minimum=0.0, maximum=10, step=0.1, value=1.0
            )
            # step=1 keeps the default (20) and the maximum (100) on the slider
            # grid. With minimum=1, a step of 5 only offers 1, 6, 11, ..., 96,
            # so neither the default nor the maximum could be selected.
            gr.Slider(elem_id="top_k", minimum=1, maximum=100, step=1, value=20)
            gr.Slider(elem_id="top_p", minimum=0.0, maximum=1.0, step=0.05, value=0.7)
            gr.Slider(
                elem_id="repetition_penalty",
                minimum=0.0,
                maximum=10,
                step=0.05,
                value=1.05,
            )
